[Kernel] Add FP8 KV cache support to Triton MLA decode attention #34597
vllm-bot merged 1 commit into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run full CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Documentation preview: https://vllm--34597.org.readthedocs.build/en/34597/ |
Code Review
This pull request successfully enables FP8 KV cache support for the Triton MLA decode attention backend. The changes are well-implemented across the documentation, backend configuration, Triton kernels, and tests.
The key changes include:
- Updating `TritonMLABackend` to advertise support for `fp8` and `fp8_e4m3` KV cache data types.
- Threading `k_scale` and `v_scale` through the decode attention call stack to the Triton kernels.
- Implementing on-the-fly dequantization for FP8 tensors within the Triton kernels, which is efficient because it leverages compile-time checks.
- Adding a comprehensive set of parameterized tests to validate the FP8 implementation against a BF16 reference, using appropriate precision tolerances for FP8 arithmetic.
The implementation is robust, and the changes are consistent and correct. The code quality is high, and I have no major concerns. This is a solid contribution to improving performance on newer GPU architectures.
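For readers unfamiliar with the scale-based scheme the review describes, here is a minimal NumPy sketch of per-tensor quantization and dequant-on-load. This is not vLLM's kernel code; the helper names are made up, and e4m3 rounding is only crudely emulated (3 explicit mantissa bits, exponent range not modeled):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite float8_e4m3 magnitude

def round_to_e4m3(x: np.ndarray) -> np.ndarray:
    # Crude emulation of e4m3 rounding: keep 3 explicit mantissa bits.
    # (Ignores e4m3's limited exponent range; fine for illustration.)
    m, e = np.frexp(x)
    return np.ldexp(np.round(m * 16) / 16, e)

def quantize_kv(x: np.ndarray) -> tuple[np.ndarray, float]:
    # Per-tensor scale chosen so the largest magnitude maps to E4M3_MAX.
    scale = float(np.abs(x).max()) / E4M3_MAX
    q = round_to_e4m3(np.clip(x / scale, -E4M3_MAX, E4M3_MAX))
    return q, scale

def dequant_on_load(q: np.ndarray, scale: float) -> np.ndarray:
    # "Mode 1": queries stay in BF16; only cached K/V are rescaled on load.
    return q * scale

rng = np.random.default_rng(0)
k = rng.standard_normal((64, 128)).astype(np.float32)
k_q, k_scale = quantize_kv(k)
k_dq = dequant_on_load(k_q, k_scale)
# 3 mantissa bits bound the relative error by 2**-4 = 6.25%
assert np.allclose(k_dq, k, rtol=0.0625, atol=1e-6)
```

In the real kernels the multiply by `scale` happens inside the Triton program right after the FP8 load, which is why the review notes it benefits from compile-time checks: for non-FP8 caches the dequant branch compiles away entirely.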
Pull request overview
This pull request enables FP8 KV cache support for the Triton MLA (Multi-head Latent Attention) decode attention backend, which is the only MLA backend available on sm120 (Blackwell consumer) GPUs. The implementation uses Mode 1 FP8 (BF16 queries + FP8 KV cache) where FP8 tensors are dequantized on load inside the Triton kernels.
Changes:
- Added FP8 and FP8_e4m3 to the list of supported KV cache data types for TritonMLABackend
- Threaded k_scale and v_scale parameters through all decode attention kernel launch paths
- Implemented FP8 dequantization in both stage1 Triton kernels (standard MHA and grouped/MLA paths)
- Added comprehensive FP8-specific parametrized test cases with proper quantization and validation
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| vllm/v1/attention/ops/triton_decode_attention.py | Added k_scale/v_scale parameters and FP8 dequantization logic in both stage1 kernels; provides dummy 1.0 scales when None |
| vllm/v1/attention/backends/mla/triton_mla.py | Added fp8/fp8_e4m3 to supported dtypes, set supports_quant_query_input=False for Mode 1, and passed layer scales to kernel |
| tests/kernels/attention/test_triton_decode_attention.py | Added test_decode_attention_fp8 with 16 parametrized test cases covering various configurations |
| docs/design/attention_backends.md | Updated TRITON_MLA backend documentation to reflect new KV cache dtype support |
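The "dummy 1.0 scales when None" pattern noted in the first row above can be sketched as follows. This is a hypothetical, heavily simplified single-head stand-in for the Triton decode kernel (no paging, no stages); only the scale-handling pattern is the point:

```python
import numpy as np

def decode_attention_fwd(q, k_cache, v_cache, k_scale=None, v_scale=None):
    """Simplified sketch: when callers pass no scales, dummy 1.0
    scales are substituted so the kernel can multiply unconditionally."""
    k_scale = 1.0 if k_scale is None else k_scale
    v_scale = 1.0 if v_scale is None else v_scale
    # Dequant-on-load: a no-op for bf16/fp16 caches (scale == 1.0).
    k = k_cache * k_scale
    v = v_cache * v_scale
    logits = (q @ k.T) / np.sqrt(q.shape[-1])
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)
    return p @ v

rng = np.random.default_rng(1)
q = rng.standard_normal((1, 64))
k = rng.standard_normal((32, 64))
v = rng.standard_normal((32, 64))
# Halving the cache and passing scale=2.0 reproduces the unscaled result.
out_a = decode_attention_fwd(q, k, v)
out_b = decode_attention_fwd(q, k / 2, v / 2, k_scale=2.0, v_scale=2.0)
assert np.allclose(out_a, out_b)
```

The real code threads per-layer scale tensors through every kernel launch path rather than Python scalars, but the None-to-dummy substitution follows the same shape.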
@grimulkan thanks for your contribution! Could you please add end-to-end lm_eval results for any of the standard models that can run on SM120?
I think the accuracy drop (~0.15 pts) is well within expected tolerance. Normalized generation speed (ignoring the potential 2x higher concurrency with fp8) is about the same as bf16, which is to be expected in this approach. Let me know if you need more tests like MMLU, etc.; they take longer to run.
I'm confirming that this is working on 8x RTX PRO on an AMD Turin host:

```shell
NCCL_P2P_LEVEL=SYS VLLM_LOG_STATS_INTERVAL=1 NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml \
VLLM_TEST_FORCE_FP8_MARLIN=1 VLLM_MARLIN_USE_ATOMIC_ADD=1 VLARLIN_INPUT_DTYPE=fp8 \
vllm serve moonshotai/Kimi-K2.5 --served-model-name Kimi-K2.5 --trust-remote-code \
  --host 0.0.0.0 --port 5000 --tensor-parallel-size 8 --pipeline-parallel-size 1 \
  --enable-chunked-prefill --enable-prefix-caching --load-format fastsafetensors \
  --tool-call-parser kimi_k2 --enable-auto-tool-choice --reasoning-parser kimi_k2 \
  --async-scheduling --gpu-memory-utilization 0.93 --max-num-batched-tokens 4096 \
  --mm-processor-cache-gb 0 --mm-encoder-tp-mode weights --language-model-only \
  --attention-backend TRITON_MLA --kv-cache-dtype fp8
```

GPU KV cache size: 449,600 tokens. When `--decode-context-parallel-size 8` is used (more KV cache), speed: 66 tok/sec.
Cross-posting these results here: some speed/VRAM benchmarks on sm120 for Kimi K2.5 on RTX 6000 Pro (native int4 experts, Marlin GEMM, Triton MLA).
All fp8 versions use this PR, and the DCP8 versions additionally use #34795. The unlocked KV cache savings are pretty huge. NOTE: This PR will likely need to be rebased and the features merged if #33529 lands first.
@LucasWilkinson For review |
Cherry-picked from:
- PR vllm-project#34597: FP8 KV cache support for Triton MLA decode attention
- PR vllm-project#34795: Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA

Changes:
- Add fp8/fp8_e4m3 to TritonMLABackend.supported_kv_cache_dtypes
- Thread k_scale/v_scale through decode attention kernel
- Add FP8 dequant-on-load in Triton kernels
- Enable DCP + FP8 KV cache combination
- Add gather_and_maybe_dequant_cache for FP8 DCP prefill path
Force-pushed from 41eb69e to 198950d
Rebased, no change in performance or functionality. I experimented with `supports_quant_query_input`; the non-sm120 attention backends currently all set it the same way, and I saw no change. Just recording here for posterity.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 198950d to ecf8fd6
Rebased (only documentation conflict) |
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from ecf8fd6 to 31403ca
MatthewBonanni left a comment:
Overall looks good, thanks for doing this! Just a few small comments.
Force-pushed from 31403ca to 9242fd0
Enable fp8/fp8_e4m3 KV cache for the Triton MLA attention backend, which is the only MLA backend available on sm120 GPUs.

- Add fp8 and fp8_e4m3 to TritonMLABackend.supported_kv_cache_dtypes
- Thread k_scale/v_scale through the decode attention kernel launch path
- Add FP8 dequant-on-load in both stage1 Triton kernels (MHA and grouped/MLA)
- Set supports_quant_query_input=False for FP8 (BF16 queries + FP8 KV)
- Add FP8-specific parametrized test cases

Signed-off-by: grimulkan <grimulkan@gmail.com>
Force-pushed from 9242fd0 to 891fc11
✅ ROCm MI300X Verification: All Tests Pass

Tested this PR on AMD Instinct MI300X (gfx942, ROCm 7.0.2) by patching it in.

- Unit tests: 16/16 PASSED ✅
- E2E serving: DeepSeek-V2-Lite (MLA) with FP8 KV cache ✅

ROCm Notes

Great work @grimulkan! 🎉
…m-project#34597) Signed-off-by: grimulkan <grimulkan@gmail.com> Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
Enable fp8/fp8_e4m3 KV cache for the Triton MLA attention backend, which is the only MLA backend available on sm120 GPUs.
Purpose
Enable FP8 KV cache for MLA models on sm120 (Blackwell consumer GPUs). The Triton MLA backend is the only available MLA backend on sm120, but previously it blocked FP8 with a `NotImplementedError`.

Changes
- Add `"fp8"` and `"fp8_e4m3"` to `TritonMLABackend.supported_kv_cache_dtypes`
- Thread `k_scale`/`v_scale` through the decode attention kernel launch path
- Set `supports_quant_query_input=False` (BF16 queries + FP8 KV cache)

Test Plan
- `pytest tests/kernels/attention/test_triton_decode_attention.py`
- `vllm serve` on a compatible model using the Triton MLA backend

Test Results
- `tests/kernels/attention/test_triton_decode_attention.py` passes
- Accuracy drop of ~0.15 pts is within expected tolerance. Normalized generation speed (ignoring the potential 2x higher concurrency with fp8) is about the same as bf16, which is to be expected in this approach.
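As a rough illustration of the FP8-vs-reference comparison the test plan describes, here is a hedged NumPy sketch. The helper names are made up and this is not the actual test file (which runs the Triton kernels against a BF16 reference on GPU); it only shows the fake-quantize-then-compare structure with FP8-appropriate tolerances:

```python
import numpy as np

E4M3_MAX = 448.0

def fake_quant_e4m3(x):
    """Emulate per-tensor e4m3 quantize->dequantize by keeping
    3 explicit mantissa bits (exponent range not modeled)."""
    scale = np.abs(x).max() / E4M3_MAX
    m, e = np.frexp(np.clip(x / scale, -E4M3_MAX, E4M3_MAX))
    return np.ldexp(np.round(m * 16) / 16, e) * scale

def attention_ref(q, k, v):
    # Plain softmax attention as the full-precision reference.
    logits = (q @ k.T) / np.sqrt(q.shape[-1])
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
for seq_len, head_dim in [(32, 64), (128, 128)]:
    q = rng.standard_normal((1, head_dim))
    k = rng.standard_normal((seq_len, head_dim))
    v = rng.standard_normal((seq_len, head_dim))
    out_ref = attention_ref(q, k, v)
    out_fp8 = attention_ref(q, fake_quant_e4m3(k), fake_quant_e4m3(v))
    # FP8 arithmetic warrants much looser tolerances than fp16/bf16.
    np.testing.assert_allclose(out_fp8, out_ref, rtol=0.2, atol=0.1)
```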
Known limitations
Essential Elements of an Effective PR Description Checklist
- Update `supported_models.md` and `examples` for a new model.